#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements.  See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License.  You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

"""Runtime type checking support.

For internal use only; no backwards-compatibility guarantees.
"""

# pytype: skip-file

import inspect
import sys
import types
from collections import abc

from apache_beam import pipeline
from apache_beam.pvalue import TaggedOutput
from apache_beam.transforms import core
from apache_beam.transforms.core import DoFn
from apache_beam.transforms.window import TimestampedValue
from apache_beam.transforms.window import WindowedValue
from apache_beam.typehints.decorators import GeneratorWrapper
from apache_beam.typehints.decorators import TypeCheckError
from apache_beam.typehints.decorators import _check_instance_type
from apache_beam.typehints.decorators import getcallargs_forhints
from apache_beam.typehints.typehints import CompositeTypeHintError
from apache_beam.typehints.typehints import SimpleTypeHintError
from apache_beam.typehints.typehints import check_constraint
from apache_beam.typehints.typehints import normalize


class AbstractDoFnWrapper(DoFn):
  """Base class for DoFns that decorate another DoFn.

  ``setup`` and ``teardown`` are forwarded directly to the wrapped DoFn, while
  ``start_bundle``, ``process`` and ``finish_bundle`` are routed through
  ``wrapper`` so that subclasses can intercept those calls in one place.
  """
  def __init__(self, dofn):
    """Stores ``dofn`` as the wrapped DoFn."""
    super().__init__()
    self.dofn = dofn

  def __getattribute__(self, name):
    """Resolves ``name`` on the wrapper first, then on the wrapped DoFn.

    Private names, instance attributes and attributes defined on the wrapper
    class are looked up on the wrapper itself; every other name is delegated
    to ``self.dofn``.
    """
    handled_here = (
        name.startswith('_') or name in self.__dict__ or
        hasattr(type(self), name))
    if not handled_here:
      return getattr(self.dofn, name)
    return object.__getattribute__(self, name)

  def _inspect_start_bundle(self):
    """Returns the wrapped DoFn's arguments for ``start_bundle``."""
    wrapped = self.dofn
    return wrapped.get_function_arguments('start_bundle')

  def _inspect_process(self):
    """Returns the wrapped DoFn's arguments for ``process``."""
    wrapped = self.dofn
    return wrapped.get_function_arguments('process')

  def _inspect_finish_bundle(self):
    """Returns the wrapped DoFn's arguments for ``finish_bundle``."""
    wrapped = self.dofn
    return wrapped.get_function_arguments('finish_bundle')

  def wrapper(self, method, args, kwargs):
    """Invokes ``method`` with ``args``/``kwargs``; the hook to override."""
    outcome = method(*args, **kwargs)
    return outcome

  def setup(self):
    """Forwards ``setup`` to the wrapped DoFn."""
    return self.dofn.setup()

  def start_bundle(self, *args, **kwargs):
    """Runs the wrapped ``start_bundle`` through ``wrapper``."""
    target = self.dofn.start_bundle
    return self.wrapper(target, args, kwargs)

  def process(self, *args, **kwargs):
    """Runs the wrapped ``process`` through ``wrapper``."""
    target = self.dofn.process
    return self.wrapper(target, args, kwargs)

  def finish_bundle(self, *args, **kwargs):
    """Runs the wrapped ``finish_bundle`` through ``wrapper``."""
    target = self.dofn.finish_bundle
    return self.wrapper(target, args, kwargs)

  def teardown(self):
    """Forwards ``teardown`` to the wrapped DoFn."""
    return self.dofn.teardown()


class OutputCheckWrapperDoFn(AbstractDoFnWrapper):
  """A DoFn wrapper that guards against common mistakes in returned outputs.

  It also re-raises any TypeCheckError coming from the wrapped call with the
  transform's full label prepended to the message.
  """
  def __init__(self, dofn, full_label):
    """Wraps ``dofn`` and remembers ``full_label`` for error messages."""
    super().__init__(dofn)
    self.full_label = full_label

  def wrapper(self, method, args, kwargs):
    """Calls ``method`` and validates the shape of what it returns.

    Raises:
      TypeCheckError: If ``method`` raised a TypeCheckError (re-raised with
        the transform label in the message) or if the returned value is
        rejected by ``_check_type``.
    """
    try:
      outputs = method(*args, **kwargs)
    except TypeCheckError as err:
      # TODO(BEAM-10710): Remove the 'ParDo' prefix for the label name
      message = (
          'Runtime type violation detected within ParDo(%s): '
          '%s' % (self.full_label, err))
      # Keep the original traceback on the re-labelled error.
      original_tb = sys.exc_info()[2]
      raise TypeCheckError(message).with_traceback(original_tb)
    return self._check_type(outputs)

  @staticmethod
  def _check_type(output):
    """Returns ``output`` unchanged if it is None or an acceptable iterable.

    Raises:
      TypeCheckError: If ``output`` is a dict, bytes or str, or if it is not
        iterable at all.
    """
    if output is None:
      return output

    # dict, bytes and str are iterable, but returning one is flagged as a
    # likely mistake.
    if isinstance(output, (dict, bytes, str)):
      kind = type(output).__name__
      raise TypeCheckError(
          'Returning a %s from a ParDo or FlatMap is '
          'discouraged. Please use list("%s") if you really '
          'want this behavior.' % (kind, output))

    if not isinstance(output, abc.Iterable):
      raise TypeCheckError(
          'FlatMap and ParDo must return an '
          'iterable. %s was returned instead.' % type(output))

    return output


class TypeCheckWrapperDoFn(AbstractDoFnWrapper):
  """A wrapper around a DoFn which performs type-checking of input and output.

  Arguments passed to ``process`` are checked against the input hints bound to
  the wrapped DoFn's process signature, and values returned by the wrapped
  calls are checked against the output hint.
  """
  def __init__(self, dofn, type_hints, label=None):
    """Binds the declared type hints to the wrapped DoFn.

    Args:
      dofn: The DoFn to wrap.
      type_hints: Object exposing ``input_types`` (a pair of positional and
        keyword hints, or a falsy value) and ``simple_output_type(label)``.
      label: Passed through to ``type_hints.simple_output_type``.
    """
    super().__init__(dofn)
    self._process_fn = self.dofn._process_argspec_fn()
    if type_hints.input_types:
      input_args, input_kwargs = type_hints.input_types
      # Map each hint onto the parameter name it applies to.
      self._input_hints = getcallargs_forhints(
          self._process_fn, *input_args, **input_kwargs)
    else:
      self._input_hints = None
    # TODO(robertwb): Multi-output.
    self._output_type_hint = type_hints.simple_output_type(label)

  def wrapper(self, method, args, kwargs):
    """Calls ``method`` and type-checks whatever it returns."""
    result = method(*args, **kwargs)
    return self._type_check_result(result)

  def process(self, *args, **kwargs):
    """Type-checks the inputs, runs the wrapped ``process``, checks outputs.

    Raises:
      TypeCheckError: If an argument or an output violates its type hint.
    """
    if self._input_hints:
      actual_inputs = inspect.getcallargs(self._process_fn, *args, **kwargs)  # pylint: disable=deprecated-method
      for var, hint in self._input_hints.items():
        if hint is actual_inputs[var]:
          # self parameter
          continue
        _check_instance_type(hint, actual_inputs[var], var, True)
    return self._type_check_result(self.dofn.process(*args, **kwargs))

  def _type_check_result(self, transform_results):
    """Type-checks every element of ``transform_results``.

    Nothing is checked when there is no output hint or when the result is
    None. Generators and other one-shot iterators are checked lazily, element
    by element as they are consumed, so that checking does not exhaust them
    before the caller sees their contents. Any other iterable is checked
    eagerly and returned unchanged.

    Args:
      transform_results: The value returned by the wrapped DoFn method.

    Returns:
      ``transform_results`` itself, or a lazily-checking wrapper around it
      when it is a generator or another iterator.

    Raises:
      TypeCheckError: If an eagerly checked element violates the output hint.
        For lazily checked results the error is raised during iteration.
    """
    if self._output_type_hint is None or transform_results is None:
      return transform_results

    def type_check_output(o):
      if isinstance(o, TimestampedValue) and hasattr(o, "__orig_class__"):
        # when a typed TimestampedValue is set, check the value type
        x = o.value
        # per https://stackoverflow.com/questions/57706180/,
        # __orig_class__ is the safe way to obtain the actual type
        # from Generic[T], supported since Python 3.5.3
        beam_type = normalize(o.__orig_class__.__args__[0])
        self.type_check(beam_type, x, is_input=False)
      else:
        # TODO(robertwb): Multi-output.
        x = o.value if isinstance(o, (TaggedOutput, WindowedValue)) else o

        self.type_check(self._output_type_hint, x, is_input=False)

    def lazily_checked(results):
      # Check each element just before handing it on to the consumer.
      for o in results:
        type_check_output(o)
        yield o

    # If the return type is a generator, then we will need to interleave our
    # type-checking with its normal iteration so we don't deplete the
    # generator initially just by type-checking its yielded contents.
    if isinstance(transform_results, types.GeneratorType):
      return GeneratorWrapper(transform_results, type_check_output)
    # The same applies to any other one-shot iterator (map, filter, iter(...),
    # itertools objects): iterating it here would hand an exhausted iterator
    # back to the caller and silently drop every output.
    if isinstance(transform_results, abc.Iterator):
      return lazily_checked(transform_results)
    for o in transform_results:
      type_check_output(o)
    return transform_results

  @staticmethod
  def type_check(type_constraint, datum, is_input):
    """Typecheck a PTransform related datum according to a type constraint.

    This function is used to optionally type-check either an input or an output
    to a PTransform.

    Args:
        type_constraint: An instance of a typehints.TypeConstraint, one of the
          white-listed builtin Python types, or a custom user class.
        datum: An instance of a Python object.
        is_input: True if 'datum' is an input to a PTransform's DoFn. False
          otherwise.

    Raises:
      TypeCheckError: If 'datum' fails to type-check according to
        'type_constraint'.
    """
    datum_type = 'input' if is_input else 'output'

    try:
      check_constraint(type_constraint, datum)
    except CompositeTypeHintError as e:
      # Composite hints already carry a descriptive message; reuse it.
      _, _, tb = sys.exc_info()
      raise TypeCheckError(e.args[0]).with_traceback(tb)
    except SimpleTypeHintError:
      error_msg = (
          "According to type-hint expected %s should be of type %s. "
          "Instead, received '%s', an instance of type %s." %
          (datum_type, type_constraint, datum, type(datum)))
      _, _, tb = sys.exc_info()
      raise TypeCheckError(error_msg).with_traceback(tb)


class TypeCheckCombineFn(core.CombineFn):
  """A CombineFn wrapper that type-checks added inputs and extracted outputs.

  All other CombineFn operations are forwarded to the wrapped CombineFn
  untouched.
  """
  def __init__(self, combinefn, type_hints, label=None):
    """Stores the wrapped CombineFn together with its type hints.

    Args:
      combinefn: The CombineFn to wrap.
      type_hints: Object exposing ``input_types`` and
        ``simple_output_type(label)``.
      label: Used to look up the output hint and in error messages.
    """
    self._combinefn = combinefn
    self._label = label
    self._input_type_hint = type_hints.input_types
    self._output_type_hint = type_hints.simple_output_type(label)

  def setup(self, *args, **kwargs):
    """Forwards ``setup`` to the wrapped CombineFn."""
    self._combinefn.setup(*args, **kwargs)

  def create_accumulator(self, *args, **kwargs):
    """Forwards ``create_accumulator`` to the wrapped CombineFn."""
    return self._combinefn.create_accumulator(*args, **kwargs)

  def add_input(self, accumulator, element, *args, **kwargs):
    """Type-checks ``element`` (if input hints exist), then forwards the call.

    Raises:
      TypeCheckError: If ``element`` does not match the second component of
        the first positional input hint; the message is prefixed with the
        transform label.
    """
    if self._input_type_hint:
      element_hint = self._input_type_hint[0][0].tuple_types[1]
      try:
        _check_instance_type(element_hint, element, 'element', True)
      except TypeCheckError as err:
        message = (
            'Runtime type violation detected within %s: '
            '%s' % (self._label, err))
        original_tb = sys.exc_info()[2]
        raise TypeCheckError(message).with_traceback(original_tb)
    return self._combinefn.add_input(accumulator, element, *args, **kwargs)

  def merge_accumulators(self, accumulators, *args, **kwargs):
    """Forwards ``merge_accumulators`` to the wrapped CombineFn."""
    return self._combinefn.merge_accumulators(accumulators, *args, **kwargs)

  def compact(self, accumulator, *args, **kwargs):
    """Forwards ``compact`` to the wrapped CombineFn."""
    return self._combinefn.compact(accumulator, *args, **kwargs)

  def extract_output(self, accumulator, *args, **kwargs):
    """Extracts the output and type-checks it when an output hint exists.

    Raises:
      TypeCheckError: If the result does not match the second component of
        the output hint; the message is prefixed with the transform label.
    """
    result = self._combinefn.extract_output(accumulator, *args, **kwargs)
    if not self._output_type_hint:
      return result
    try:
      _check_instance_type(
          self._output_type_hint.tuple_types[1], result, None, True)
    except TypeCheckError as err:
      message = (
          'Runtime type violation detected within %s: '
          '%s' % (self._label, err))
      original_tb = sys.exc_info()[2]
      raise TypeCheckError(message).with_traceback(original_tb)
    return result

  def teardown(self, *args, **kwargs):
    """Forwards ``teardown`` to the wrapped CombineFn."""
    self._combinefn.teardown(*args, **kwargs)


class TypeCheckVisitor(pipeline.PipelineVisitor):
  """Pipeline visitor that installs the runtime type-checking wrappers.

  CombinePerKey transforms get their CombineFn wrapped in a
  TypeCheckCombineFn; ParDo transforms outside of a CombinePerKey get their
  DoFn wrapped in TypeCheckWrapperDoFn and OutputCheckWrapperDoFn.
  """

  # True while the traversal is inside a CombinePerKey composite.
  _in_combine = False

  def enter_composite_transform(self, applied_transform):
    """Wraps the CombineFn when entering a CombinePerKey composite."""
    if not isinstance(applied_transform.transform, core.CombinePerKey):
      return
    self._in_combine = True
    checked_fn = TypeCheckCombineFn(
        applied_transform.transform.fn,
        applied_transform.transform.get_type_hints(),
        applied_transform.full_label)
    # Remember the wrapper so the inner CombineValuesDoFn can reuse it.
    self._wrapped_fn = checked_fn
    applied_transform.transform.fn = checked_fn

  def leave_composite_transform(self, applied_transform):
    """Clears the in-combine flag when leaving a CombinePerKey composite."""
    if not isinstance(applied_transform.transform, core.CombinePerKey):
      return
    self._in_combine = False

  def visit_transform(self, applied_transform):
    """Installs the type-checking wrappers on ParDo transforms."""
    transform = applied_transform.transform
    if not isinstance(transform, core.ParDo):
      return

    if self._in_combine:
      # Inside a CombinePerKey only the CombineValuesDoFn is touched: it is
      # pointed at the already-wrapped CombineFn.
      if isinstance(transform.fn, core.CombineValuesDoFn):
        transform.fn.combinefn = self._wrapped_fn
      return

    label = applied_transform.full_label
    type_checked = TypeCheckWrapperDoFn(
        transform.fn, transform.get_type_hints(), label)
    output_checked = OutputCheckWrapperDoFn(type_checked, label)
    transform.fn = output_checked
    transform.dofn = output_checked


class PerformanceTypeCheckVisitor(pipeline.PipelineVisitor):
  """Pipeline visitor that records type constraints on transforms.

  For every visited transform, its own output hint is registered on the
  transform itself and its input hint is registered on the producer of its
  first input, both through ``_add_type_constraint_from_consumer``.
  """
  def visit_transform(self, applied_transform):
    """Registers the output and input type hints of ``applied_transform``."""
    transform = applied_transform.transform
    full_label = applied_transform.full_label

    # Store output type hints in current transform
    output_type_hints = self.get_output_type_hints(transform)
    if output_type_hints:
      transform._add_type_constraint_from_consumer(
          full_label, output_type_hints)

    # Store input type hints in producer transform
    input_type_hints = self.get_input_type_hints(transform)
    if input_type_hints and len(applied_transform.inputs):
      producer = applied_transform.inputs[0].producer
      if producer:
        producer.transform._add_type_constraint_from_consumer(
            full_label, input_type_hints)

  def get_input_type_hints(self, transform):
    """Returns ``(parameter_name, input_type)`` for the transform's main input.

    Positional hints take precedence over keyword hints. Keyword hints are
    resolved through the name of the first non-``self`` parameter of the
    transform's process function.

    Args:
      transform: The transform whose ``get_type_hints()`` is inspected.

    Returns:
      A 2-tuple of the parameter name (``'Unknown Parameter'`` when it cannot
      be determined) and the hint for that parameter, or None when no hint
      is available for it.
    """
    type_hints = transform.get_type_hints()

    input_types = None
    if type_hints.input_types:
      normal_hints, kwarg_hints = type_hints.input_types
      input_types = normal_hints or kwarg_hints or None

    parameter_name = 'Unknown Parameter'
    if hasattr(transform, 'fn'):
      try:
        argspec = inspect.getfullargspec(transform.fn._process_argspec_fn())
      except TypeError:
        # An unsupported callable was passed to getfullargspec
        pass
      else:
        if argspec.args:
          arg_index = 0
          if argspec.args[0] == 'self' and len(argspec.args) > 1:
            arg_index = 1
          parameter_name = argspec.args[arg_index]
          if isinstance(input_types, dict):
            # Keyword hints may only cover other parameters (for example
            # side inputs); a missing entry means "no hint" rather than an
            # error, so avoid raising KeyError here.
            input_types = (input_types.get(parameter_name), )

    if isinstance(input_types, dict):
      # Keyword hints that could not be matched to a parameter name cannot
      # be attributed to the main input.
      input_types = None
    elif input_types:
      input_types = input_types[0]

    return parameter_name, input_types

  def get_output_type_hints(self, transform):
    """Returns ``(None, output_type)`` for the transform.

    Positional hints take precedence over keyword hints; the first entry of
    the selected hints is returned, or None when there are no output hints.
    """
    type_hints = transform.get_type_hints()

    output_types = None
    if type_hints.output_types:
      normal_hints, kwarg_hints = type_hints.output_types
      if kwarg_hints:
        output_types = kwarg_hints
      if normal_hints:
        output_types = normal_hints

    if output_types and len(output_types):
      output_types = output_types[0]

    return None, output_types
